{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Uncertainty Quantification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/uncertainty.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        !pip install .\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "from lightning import pytorch as pl\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint\n",
    "\n",
    "from chemprop import data, models, nn, uncertainty\n",
    "from chemprop.models import save_model, load_model\n",
    "from chemprop.cli.conf import NOW\n",
    "from chemprop.cli.predict import find_models\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loda data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "input_path = (\n",
    "    chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\"\n",
    ")  # path to your data .csv file\n",
    "df_input = pd.read_csv(input_path)\n",
    "smis = df_input.loc[:, \"smiles\"].values\n",
    "ys = df_input.loc[:, [\"lipo\"]].values\n",
    "all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols = [d.mol for d in all_data]  # RDkit Mol objects are use for structure based splits\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(\n",
    "    all_data, train_indices, val_indices, test_indices\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dset = data.MoleculeDataset(train_data[0])\n",
    "scaler = train_dset.normalize_targets()\n",
    "\n",
    "val_dset = data.MoleculeDataset(val_data[0])\n",
    "val_dset.normalize_targets(scaler)\n",
    "\n",
    "test_dset = data.MoleculeDataset(test_data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = data.build_dataloader(train_dset)\n",
    "val_loader = data.build_dataloader(val_dset, shuffle=False)\n",
    "test_loader = data.build_dataloader(test_dset, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constructs MPNN\n",
    "\n",
    "- A `Message passing` constructs molecular graphs using message passing to learn node-level hidden representations.\n",
    "\n",
    "- An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.\n",
    "\n",
    "- A `FFN` takes the aggregated representations and make target predictions. To obtain uncertainty predictions, the `FFN` must be modified accordingly.\n",
    "\n",
    "    For regression:\n",
    "    - `ffn = nn.RegressionFFN()`\n",
    "    - `ffn = nn.MveFFN()`\n",
    "    - `ffn = nn.EvidentialFFN()`\n",
    "\n",
    "    For classification:\n",
    "    - `ffn = nn.BinaryClassificationFFN()`\n",
    "    - `ffn = nn.BinaryDirichletFFN()`\n",
    "    - `ffn = nn.MulticlassClassificationFFN()`\n",
    "    - `ffn = nn.MulticlassDirichletFFN()`\n",
    "\n",
    "    For spectral:\n",
    "    - `ffn = nn.SpectralFFN()` # will be available in future version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPNN(\n",
       "  (message_passing): BondMessagePassing(\n",
       "    (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "    (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "    (tau): ReLU()\n",
       "    (V_d_transform): Identity()\n",
       "    (graph_transform): Identity()\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): Identity()\n",
       "  (predictor): MveFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=2, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MVELoss(task_weights=[[1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0): MSE(task_weights=[[1.0]])\n",
       "    (1): MVELoss(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mp = nn.BondMessagePassing()\n",
    "agg = nn.MeanAggregation()\n",
    "output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)\n",
    "# Change to other predictor if needed.\n",
    "ffn = nn.MveFFN(output_transform=output_transform)\n",
    "mpnn = models.MPNN(mp, agg, ffn, batch_norm=False)\n",
    "mpnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_output_dir = Path(f\"chemprop_training/{NOW}\")\n",
    "monitor_mode = \"min\" if mpnn.metrics[0].higher_is_better else \"max\"\n",
    "checkpointing = ModelCheckpoint(\n",
    "    model_output_dir / \"checkpoints\",\n",
    "    \"best-{epoch}-{val_loss:.2f}\",\n",
    "    \"val_loss\",\n",
    "    mode=monitor_mode,\n",
    "    save_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(\n",
    "    logger=False,\n",
    "    enable_checkpointing=True,\n",
    "    enable_progress_bar=False,\n",
    "    accelerator=\"cpu\",\n",
    "    callbacks=[checkpointing],\n",
    "    devices=1,\n",
    "    max_epochs=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name            | Type               | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing | 227 K  | train\n",
      "1 | agg             | MeanAggregation    | 0      | train\n",
      "2 | bn              | Identity           | 0      | train\n",
      "3 | predictor       | MveFFN             | 90.9 K | train\n",
      "4 | X_d_transform   | Identity           | 0      | train\n",
      "5 | metrics         | ModuleList         | 0      | train\n",
      "---------------------------------------------------------------\n",
      "318 K     Trainable params\n",
      "0         Non-trainable params\n",
      "318 K     Total params\n",
      "1.274     Total estimated model params size (MB)\n",
      "24        Modules in train mode\n",
      "0         Modules in eval mode\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(mpnn, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_model_path = checkpointing.best_model_path\n",
    "model = mpnn.__class__.load_from_checkpoint(best_model_path)\n",
    "p_model = model_output_dir / \"best.pt\"\n",
    "save_model(p_model, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predicting"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Change model input here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>lipo</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14</td>\n",
       "      <td>3.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...</td>\n",
       "      <td>-1.18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl</td>\n",
       "      <td>3.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...</td>\n",
       "      <td>3.37</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...</td>\n",
       "      <td>3.10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...</td>\n",
       "      <td>2.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...</td>\n",
       "      <td>2.04</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...</td>\n",
       "      <td>4.49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>COc1ccc(Cc2c(N)n[nH]c2N)cc1</td>\n",
       "      <td>0.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...</td>\n",
       "      <td>2.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               smiles  lipo\n",
       "0             Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14  3.54\n",
       "1   COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18\n",
       "2              COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl  3.69\n",
       "3   OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...  3.37\n",
       "4   Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...  3.10\n",
       "..                                                ...   ...\n",
       "95  CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...  2.20\n",
       "96  CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...  2.04\n",
       "97  CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...  4.49\n",
       "98                        COc1ccc(Cc2c(N)n[nH]c2N)cc1  0.20\n",
       "99  CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...  2.00\n",
       "\n",
       "[100 rows x 2 columns]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "test_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\"\n",
    "df_test = pd.read_csv(test_path)\n",
    "test_dset = data.MoleculeDataset(test_data[0])\n",
    "test_loader = data.build_dataloader(test_dset, shuffle=False)\n",
    "df_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the validation set from the training as the calibration set as an example\n",
    "cal_dset = data.MoleculeDataset(val_data[0])\n",
    "cal_loader = data.build_dataloader(cal_dset, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constructs uncertainty estimator\n",
    "An uncertianty estimator can make model predictions and associated uncertainty predictions.\n",
    "\n",
    "Available options can be found in `uncertainty.UncertaintyEstimatorRegistry`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ClassRegistry {\n",
      "    'none': <class 'chemprop.uncertainty.estimator.NoUncertaintyEstimator'>,\n",
      "    'mve': <class 'chemprop.uncertainty.estimator.MVEEstimator'>,\n",
      "    'ensemble': <class 'chemprop.uncertainty.estimator.EnsembleEstimator'>,\n",
      "    'classification': <class 'chemprop.uncertainty.estimator.ClassEstimator'>,\n",
      "    'evidential-total': <class 'chemprop.uncertainty.estimator.EvidentialTotalEstimator'>,\n",
      "    'evidential-epistemic': <class 'chemprop.uncertainty.estimator.EvidentialEpistemicEstimator'>,\n",
      "    'evidential-aleatoric': <class 'chemprop.uncertainty.estimator.EvidentialAleatoricEstimator'>,\n",
      "    'dropout': <class 'chemprop.uncertainty.estimator.DropoutEstimator'>,\n",
      "    'classification-dirichlet': <class 'chemprop.uncertainty.estimator.ClassificationDirichletEstimator'>,\n",
      "    'multiclass-dirichlet': <class 'chemprop.uncertainty.estimator.MulticlassDirichletEstimator'>,\n",
      "    'quantile-regression': <class 'chemprop.uncertainty.estimator.QuantileRegressionEstimator'>\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(uncertainty.UncertaintyEstimatorRegistry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "unc_estimator = uncertainty.MVEEstimator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constructs uncertainty calibrator\n",
    "An uncertianty calibrator can calibrate the predicted uncertainties.\n",
    "\n",
    "Available options can be found in `uncertainty.UncertaintyCalibratorRegistry`.\n",
    "\n",
    "For regression:\n",
    "\n",
    "- ZScalingCalibrator\n",
    "\n",
    "- ZelikmanCalibrator\n",
    "\n",
    "- MVEWeightingCalibrator\n",
    "\n",
    "- RegressionConformalCalibrator\n",
    "\n",
    "For binary classification:\n",
    "\n",
    "- PlattCalibrator\n",
    "\n",
    "- IsotonicCalibrator\n",
    "\n",
    "- MultilabelConformalCalibrator\n",
    "\n",
    "For multiclass classification:\n",
    "\n",
    "- MulticlassConformalCalibrator\n",
    "\n",
    "- AdaptiveMulticlassConformalCalibrator\n",
    "\n",
    "- IsotonicMulticlassCalibrator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ClassRegistry {\n",
      "    'zscaling': <class 'chemprop.uncertainty.calibrator.ZScalingCalibrator'>,\n",
      "    'zelikman-interval': <class 'chemprop.uncertainty.calibrator.ZelikmanCalibrator'>,\n",
      "    'mve-weighting': <class 'chemprop.uncertainty.calibrator.MVEWeightingCalibrator'>,\n",
      "    'conformal-regression': <class 'chemprop.uncertainty.calibrator.RegressionConformalCalibrator'>,\n",
      "    'platt': <class 'chemprop.uncertainty.calibrator.PlattCalibrator'>,\n",
      "    'isotonic': <class 'chemprop.uncertainty.calibrator.IsotonicCalibrator'>,\n",
      "    'conformal-multilabel': <class 'chemprop.uncertainty.calibrator.MultilabelConformalCalibrator'>,\n",
      "    'conformal-multiclass': <class 'chemprop.uncertainty.calibrator.MulticlassConformalCalibrator'>,\n",
      "    'conformal-adaptive': <class 'chemprop.uncertainty.calibrator.AdaptiveMulticlassConformalCalibrator'>,\n",
      "    'isotonic-multiclass': <class 'chemprop.uncertainty.calibrator.IsotonicMulticlassCalibrator'>\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(uncertainty.UncertaintyCalibratorRegistry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "unc_calibrator = uncertainty.ZScalingCalibrator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constructs uncertainty evaluator\n",
    "An uncertianty evaluator can evaluates the quality of uncertainty estimates.\n",
    "\n",
    "Available options can be found in `uncertainty.UncertaintyEvaluatorRegistry`.\n",
    "\n",
    "For regression:\n",
    "\n",
    "- NLLRegressionEvaluator\n",
    "\n",
    "- CalibrationAreaEvaluator\n",
    "\n",
    "- ExpectedNormalizedErrorEvaluator\n",
    "\n",
    "- SpearmanEvaluator\n",
    "\n",
    "- RegressionConformalEvaluator\n",
    "\n",
    "For binary classification:\n",
    "\n",
    "- NLLClassEvaluator\n",
    "\n",
    "- MultilabelConformalEvaluator\n",
    "\n",
    "\n",
    "For multiclass classification:\n",
    "\n",
    "- NLLMulticlassEvaluator\n",
    "\n",
    "- MulticlassConformalEvaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ClassRegistry {\n",
      "    'nll-regression': <class 'chemprop.uncertainty.evaluator.NLLRegressionEvaluator'>,\n",
      "    'miscalibration_area': <class 'chemprop.uncertainty.evaluator.CalibrationAreaEvaluator'>,\n",
      "    'ence': <class 'chemprop.uncertainty.evaluator.ExpectedNormalizedErrorEvaluator'>,\n",
      "    'spearman': <class 'chemprop.uncertainty.evaluator.SpearmanEvaluator'>,\n",
      "    'conformal-coverage-regression': <class 'chemprop.uncertainty.evaluator.RegressionConformalEvaluator'>,\n",
      "    'nll-classification': <class 'chemprop.uncertainty.evaluator.NLLClassEvaluator'>,\n",
      "    'conformal-coverage-classification': <class 'chemprop.uncertainty.evaluator.MultilabelConformalEvaluator'>,\n",
      "    'nll-multiclass': <class 'chemprop.uncertainty.evaluator.NLLMulticlassEvaluator'>,\n",
      "    'conformal-coverage-multiclass': <class 'chemprop.uncertainty.evaluator.MulticlassConformalEvaluator'>\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(uncertainty.UncertaintyEvaluatorRegistry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "unc_evaluators = [\n",
    "    uncertainty.NLLRegressionEvaluator(),\n",
    "    uncertainty.CalibrationAreaEvaluator(),\n",
    "    uncertainty.ExpectedNormalizedErrorEvaluator(),\n",
    "    uncertainty.SpearmanEvaluator(),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_paths = find_models([model_output_dir])\n",
    "models = [load_model(model_path, multicomponent=False) for model_path in model_paths]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(logger=False, enable_progress_bar=True, accelerator=\"cpu\", devices=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make uncertainty estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|███████████████████| 1/1 [00:00<00:00, 126.93it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>target</th>\n",
       "      <th>pred</th>\n",
       "      <th>unc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3...</td>\n",
       "      <td>2.06</td>\n",
       "      <td>2.047474</td>\n",
       "      <td>1.543233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C...</td>\n",
       "      <td>1.92</td>\n",
       "      <td>2.047561</td>\n",
       "      <td>1.534631</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CNCCCC12CCC(c3ccccc31)c1ccccc12</td>\n",
       "      <td>0.89</td>\n",
       "      <td>2.062057</td>\n",
       "      <td>1.548673</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Oc1ncnc2scc(-c3ccsc3)c12</td>\n",
       "      <td>2.25</td>\n",
       "      <td>2.061813</td>\n",
       "      <td>1.555989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3...</td>\n",
       "      <td>2.04</td>\n",
       "      <td>2.038238</td>\n",
       "      <td>1.532385</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1</td>\n",
       "      <td>3.13</td>\n",
       "      <td>2.048835</td>\n",
       "      <td>1.535416</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>O=C(COc1ccccc1)c1ccccc1</td>\n",
       "      <td>2.87</td>\n",
       "      <td>2.066844</td>\n",
       "      <td>1.534430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1</td>\n",
       "      <td>1.10</td>\n",
       "      <td>2.053771</td>\n",
       "      <td>1.550390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1</td>\n",
       "      <td>-0.16</td>\n",
       "      <td>2.047554</td>\n",
       "      <td>1.535353</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>COc1cnc(-c2ccccn2)nc1N(C)C</td>\n",
       "      <td>1.90</td>\n",
       "      <td>2.050501</td>\n",
       "      <td>1.537318</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              smiles  target      pred  \\\n",
       "0  Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3...    2.06  2.047474   \n",
       "1  O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C...    1.92  2.047561   \n",
       "2                    CNCCCC12CCC(c3ccccc31)c1ccccc12    0.89  2.062057   \n",
       "3                           Oc1ncnc2scc(-c3ccsc3)c12    2.25  2.061813   \n",
       "4  C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3...    2.04  2.038238   \n",
       "5      COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1    3.13  2.048835   \n",
       "6                            O=C(COc1ccccc1)c1ccccc1    2.87  2.066844   \n",
       "7           CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1    1.10  2.053771   \n",
       "8        N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1   -0.16  2.047554   \n",
       "9                         COc1cnc(-c2ccccn2)nc1N(C)C    1.90  2.050501   \n",
       "\n",
       "        unc  \n",
       "0  1.543233  \n",
       "1  1.534631  \n",
       "2  1.548673  \n",
       "3  1.555989  \n",
       "4  1.532385  \n",
       "5  1.535416  \n",
       "6  1.534430  \n",
       "7  1.550390  \n",
       "8  1.535353  \n",
       "9  1.537318  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_predss, test_uncss = unc_estimator(test_loader, models, trainer)\n",
    "test_preds = test_predss.mean(0)\n",
    "test_uncs = test_uncss.mean(0)\n",
    "\n",
    "df_test = pd.DataFrame(\n",
    "    {\n",
    "        \"smiles\": test_dset.smiles,\n",
    "        \"target\": test_dset.Y.reshape(-1),\n",
    "        \"pred\": test_preds.reshape(-1),\n",
    "        \"unc\": test_uncs.reshape(-1),\n",
    "    }\n",
    ")\n",
    "\n",
    "df_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Apply uncertainty calibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting DataLoader 0: 100%|███████████████████| 1/1 [00:00<00:00, 228.26it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>target</th>\n",
       "      <th>pred</th>\n",
       "      <th>unc</th>\n",
       "      <th>cal_unc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3...</td>\n",
       "      <td>2.06</td>\n",
       "      <td>2.047474</td>\n",
       "      <td>1.543233</td>\n",
       "      <td>1.691122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C...</td>\n",
       "      <td>1.92</td>\n",
       "      <td>2.047561</td>\n",
       "      <td>1.534631</td>\n",
       "      <td>1.681696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CNCCCC12CCC(c3ccccc31)c1ccccc12</td>\n",
       "      <td>0.89</td>\n",
       "      <td>2.062057</td>\n",
       "      <td>1.548673</td>\n",
       "      <td>1.697084</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Oc1ncnc2scc(-c3ccsc3)c12</td>\n",
       "      <td>2.25</td>\n",
       "      <td>2.061813</td>\n",
       "      <td>1.555989</td>\n",
       "      <td>1.705101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3...</td>\n",
       "      <td>2.04</td>\n",
       "      <td>2.038238</td>\n",
       "      <td>1.532385</td>\n",
       "      <td>1.679235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1</td>\n",
       "      <td>3.13</td>\n",
       "      <td>2.048835</td>\n",
       "      <td>1.535416</td>\n",
       "      <td>1.682556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>O=C(COc1ccccc1)c1ccccc1</td>\n",
       "      <td>2.87</td>\n",
       "      <td>2.066844</td>\n",
       "      <td>1.534430</td>\n",
       "      <td>1.681475</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1</td>\n",
       "      <td>1.10</td>\n",
       "      <td>2.053771</td>\n",
       "      <td>1.550390</td>\n",
       "      <td>1.698965</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1</td>\n",
       "      <td>-0.16</td>\n",
       "      <td>2.047554</td>\n",
       "      <td>1.535353</td>\n",
       "      <td>1.682488</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>COc1cnc(-c2ccccn2)nc1N(C)C</td>\n",
       "      <td>1.90</td>\n",
       "      <td>2.050501</td>\n",
       "      <td>1.537318</td>\n",
       "      <td>1.684641</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              smiles  target      pred  \\\n",
       "0  Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3...    2.06  2.047474   \n",
       "1  O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C...    1.92  2.047561   \n",
       "2                    CNCCCC12CCC(c3ccccc31)c1ccccc12    0.89  2.062057   \n",
       "3                           Oc1ncnc2scc(-c3ccsc3)c12    2.25  2.061813   \n",
       "4  C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3...    2.04  2.038238   \n",
       "5      COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1    3.13  2.048835   \n",
       "6                            O=C(COc1ccccc1)c1ccccc1    2.87  2.066844   \n",
       "7           CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1    1.10  2.053771   \n",
       "8        N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1   -0.16  2.047554   \n",
       "9                         COc1cnc(-c2ccccn2)nc1N(C)C    1.90  2.050501   \n",
       "\n",
       "        unc   cal_unc  \n",
       "0  1.543233  1.691122  \n",
       "1  1.534631  1.681696  \n",
       "2  1.548673  1.697084  \n",
       "3  1.555989  1.705101  \n",
       "4  1.532385  1.679235  \n",
       "5  1.535416  1.682556  \n",
       "6  1.534430  1.681475  \n",
       "7  1.550390  1.698965  \n",
       "8  1.535353  1.682488  \n",
       "9  1.537318  1.684641  "
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cal_predss, cal_uncss = unc_estimator(cal_loader, models, trainer)\n",
    "average_cal_preds = cal_predss.mean(0)\n",
    "average_cal_uncs = cal_uncss.mean(0)\n",
    "cal_targets = cal_dset.Y\n",
    "cal_mask = torch.from_numpy(np.isfinite(cal_targets))\n",
    "cal_targets = np.nan_to_num(cal_targets, nan=0.0)\n",
    "cal_targets = torch.from_numpy(cal_targets)\n",
    "unc_calibrator.fit(average_cal_preds, average_cal_uncs, cal_targets, cal_mask)\n",
    "\n",
    "cal_test_uncs = unc_calibrator.apply(test_uncs)\n",
    "df_test[\"cal_unc\"] = cal_test_uncs\n",
    "df_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate predicted uncertainty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nll-regression: [1.4490190356267003]\n",
      "miscalibration_area: [0.15619999170303345]\n",
      "ence: [0.6248166925739804]\n",
      "spearman: [0.27272725105285645]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
      "  warnings.warn(*args, **kwargs)  # noqa: B028\n"
     ]
    }
   ],
   "source": [
    "test_targets = test_dset.Y\n",
    "test_mask = torch.from_numpy(np.isfinite(test_targets))\n",
    "test_targets = np.nan_to_num(test_targets, nan=0.0)\n",
    "test_targets = torch.from_numpy(test_targets)\n",
    "\n",
    "for evaluator in unc_evaluators:\n",
    "    evaluation = evaluator.evaluate(test_preds, cal_test_uncs, test_targets, test_mask)\n",
    "    print(f\"{evaluator.alias}: {evaluation.tolist()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
